{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3",
   "language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "source": [
    "\n",
    "# Download the Unsplash dataset\n",
    "\n",
    "This notebook downloads all images from the Unsplash dataset: https://github.com/unsplash/datasets. There are two versions: Lite (25,000 images) and Full (2M images). For the Full version you need to apply for access (see [here](https://unsplash.com/data)). Downloading the images lets you run CLIP on the whole dataset yourself.\n",
    "\n",
    "Put the TSV files in the folder `unsplash-dataset/full` or `unsplash-dataset/lite`, or adjust the path in the cell below."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "dataset_version = \"lite\"  # either \"lite\" or \"full\"\n",
    "unsplash_dataset_path = Path(\"unsplash-dataset\") / dataset_version"
   ]
  },
  {
   "source": [
    "## Load the dataset\n",
    "\n",
    "The `photos.tsv000` file contains metadata about the photos in the dataset, but not the photos themselves. We will use the photo URLs from this file to download the actual images."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Photos in the dataset: 25000\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Read the photos table\n",
    "photos = pd.read_csv(unsplash_dataset_path / \"photos.tsv000\", sep='\\t', header=0)\n",
    "\n",
    "# Extract the IDs and the URLs of the photos\n",
    "photo_urls = photos[['photo_id', 'photo_image_url']].values.tolist()\n",
    "\n",
    "# Print some statistics\n",
    "print(f'Photos in the dataset: {len(photo_urls)}')"
   ]
  },
  {
   "source": [
    "The file name of each photo corresponds to its unique ID from Unsplash. We will download the photos at a reduced resolution (640 pixels wide), because CLIP downscales them anyway."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import urllib.request\n",
    "\n",
    "# Path where the photos will be downloaded\n",
    "photos_download_path = unsplash_dataset_path / \"photos\"\n",
    "\n",
    "# Create the download folder if it is missing; urlretrieve does not create it\n",
    "photos_download_path.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Function that downloads a single photo\n",
    "def download_photo(photo):\n",
    "    # Get the ID of the photo\n",
    "    photo_id = photo[0]\n",
    "\n",
    "    # Get the URL of the photo (setting the width to 640 pixels)\n",
    "    photo_url = photo[1] + \"?w=640\"\n",
    "\n",
    "    # Path where the photo will be stored\n",
    "    photo_path = photos_download_path / (photo_id + \".jpg\")\n",
    "\n",
    "    # Only download a photo if it doesn't exist\n",
    "    if not photo_path.exists():\n",
    "        try:\n",
    "            urllib.request.urlretrieve(photo_url, photo_path)\n",
    "        except Exception as error:\n",
    "            # Report the failure and continue with the remaining photos\n",
    "            print(f\"Cannot download {photo_url}: {error}\")"
   ]
  },
  {
   "source": [
    "Now the actual download! The download parallelizes very well, so we will use a thread pool. You may need to tune the `threads_count` parameter to get optimal performance for your Internet connection. For me, even 128 worked quite well."
   ],
   "cell_type": "markdown",
   "metadata": {}
  },
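  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from multiprocessing.pool import ThreadPool\n",
    "\n",
    "# Number of download threads\n",
    "threads_count = 16\n",
    "\n",
    "# Start the download; the context manager shuts the pool down afterwards\n",
    "with ThreadPool(threads_count) as pool:\n",
    "    pool.map(download_photo, photo_urls)\n",
    "\n",
    "# Count the files actually on disk, because failed downloads are skipped\n",
    "downloaded_count = len(list((unsplash_dataset_path / \"photos\").glob(\"*.jpg\")))\n",
    "print(f'Photos downloaded: {downloaded_count} of {len(photo_urls)}')"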
   ]
  }
 ]
}